In [26]:
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib import patches
from torchvision import io
import torch
import numpy as np
from CONFIG import config
from datalib import build_data_loader, load_data
from utils.utils import count_model_params, load_model
from torchinfo import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = config['data']['dataset_path']
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

- Data Visualization

  • Here, we investigate the transforms on the train dataset. We compare the data in a sequence before and after the transformations are applied.

  • The transforms are applied equally to all samples in a sequence, preserving the temporal consistency of the video sequence.

  • Each new sequence gets fresh random decisions for augmentation. This is done in the MoviC.py file at line 49: self.transforms.reset_sequence(sequence_idx=idx, num_epochs=self.num_epochs) resets the flags for horizontal and vertical flips.

  • To apply the same transformation to all data in a sequence, we needed a different seed per sequence, but at the same time a stable base seed for training the model and initializing the tensors. To solve this, we use the training seed as base_seed, increased by idx (the index of the sequence in the dataset) and by a random value in range(epoch). This ensures consistency of the augmentation within each sequence, diversity across sequences and epochs, and a defined base seed for tensor initialization during training. (This is done in the _make_sequence_decisions() function at line 41 in utils/transforms.py.) Note that although this approach is deterministic per sequence within a run, it is non-deterministic across runs!

  • For each sequence, the chosen augmentation can be either vertical, horizontal, both, or neither. This is also done in the _make_sequence_decisions() function. We use independent probabilities for each augmentation:

    • should_hflip = random.random() < 0.3: 30% chance of horizontal flip.
    • should_vflip = random.random() < 0.7: 70% chance of vertical flip.

    Since these are independent, the possible outcomes per sequence are:

    • Neither: (1 − 0.3) × (1 − 0.7) = 0.7 × 0.3 = 21% probability.

    • Horizontal only: 0.3 × (1 − 0.7) = 0.3 × 0.3 = 9% probability.

    • Vertical only: (1 − 0.3) × 0.7 = 0.7 × 0.7 = 49% probability.

    • Both: 0.3 × 0.7 = 21% probability.
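As a sketch of the decision logic described above (the real implementation lives in _make_sequence_decisions() in the project code; the function body and seeding details below are illustrative, not the project's exact code):

```python
import random

HFLIP_P, VFLIP_P = 0.3, 0.7  # independent flip probabilities from the text

def make_sequence_decisions(base_seed: int, idx: int, epoch: int) -> dict:
    """Illustrative per-sequence augmentation decisions.

    Seeding a local RNG with base_seed + idx (+ a value drawn from
    range(epoch)) keeps all frames of one sequence consistent while varying
    the augmentation across sequences and epochs."""
    # The range(epoch) draw uses the global RNG, so results differ per run
    offset = random.randrange(epoch) if epoch > 0 else 0
    rng = random.Random(base_seed + idx + offset)
    return {
        "hflip": rng.random() < HFLIP_P,
        "vflip": rng.random() < VFLIP_P,
    }

# Possible outcomes: neither 0.7*0.3 = 21%, hflip only 0.3*0.3 = 9%,
# vflip only 0.7*0.7 = 49%, both 0.3*0.7 = 21%
decisions = make_sequence_decisions(base_seed=42, idx=7, epoch=3)
print(decisions)
```

With epoch fixed, the same (base_seed, idx) pair always yields the same flips, which is what keeps a sequence temporally consistent.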

In [ ]:
# Apply transforms to ENTIRE sequence using transform pipeline

train_dataset = load_data(path, split='train', use_transforms=False, Visualize=True)
train_dataset_transformed = load_data(path, split='train', use_transforms=True, Visualize=True)
[INFO] - TRAIN Data Loaded: Coordinates: 9737, Masks: 9737, RGB videos:  9737, Flows:  9737
[INFO] - TRAIN Data Loaded: Coordinates: 9737, Masks: 9737, RGB videos:  9737, Flows:  9737

First run: only vertical flip

In [66]:
from utils.visualization import plot_transform_comparison


idx = np.random.randint(0, len(train_dataset))

rgbs_orig, masks_orig, flows_orig, coords_orig = train_dataset[idx]
rgbs_trans, masks_trans, flows_trans, coords_trans = train_dataset_transformed[idx]


plot_transform_comparison(
                        rgbs_orig, masks_orig, flows_orig, coords_orig,
                        rgbs_trans, masks_trans, flows_trans, coords_trans,
                        n_rows=6, sequence_idx=idx
                        )
[Output image]

Second run: Both horizontal and vertical flips

In [ ]:
from utils.visualization import plot_transform_comparison

idx = np.random.randint(0, len(train_dataset))

rgbs_orig, masks_orig, flows_orig, coords_orig = train_dataset[idx]
rgbs_trans, masks_trans, flows_trans, coords_trans = train_dataset_transformed[idx]


plot_transform_comparison(
                        rgbs_orig, masks_orig, flows_orig, coords_orig,
                        rgbs_trans, masks_trans, flows_trans, coords_trans,
                        n_rows=6, sequence_idx=idx
                        )
[Output image]

- Dataloaders and modality shapes

In [27]:
val_dataset = load_data(path, split='validation', use_transforms=True)

# train_loader= build_data_loader(train_dataset, split='train')
val_loader = build_data_loader(val_dataset, split='validation')
[INFO] - VALIDATION Data Loaded: Coordinates: 250, Masks: 250, RGB videos:  250, Flows:  250
In [28]:
## Verifying the dataloader  
rgbs, masks, flows, coords = next(iter(val_loader))

# Send all tensors to device
rgbs = rgbs.to(device)
flows = flows.to(device)

# Move all mask tensors to device
for k in masks:
    masks[k] = masks[k].to(device)

# Move all coords tensors to device
for k in coords:
    coords[k] = coords[k].to(device)

print(f"RGBs shape: {rgbs.shape}\nFlows shape: {flows.shape}\nMasks shape: {masks['masks'].shape} \nCoords com shape: {coords['com'].shape}\nCoords bbxs shape: {coords['bbox'].shape}")
RGBs shape: torch.Size([32, 24, 3, 128, 128])
Flows shape: torch.Size([32, 24, 3, 128, 128])
Masks shape: torch.Size([32, 24, 128, 128]) 
Coords com shape: torch.Size([32, 24, 11, 2])
Coords bbxs shape: torch.Size([32, 24, 11, 4])

- Utility visualization

Here, we will visualize some of the helper functions used during training.

  1. During object-centric scene representation learning, each token is an object image. Therefore, we need to extract object images from frames. We can achieve this using either bboxes or mask labels. First, we explore extracting object frames from masks:
In [35]:
'''
Used during training of the object-centric model, where mask labels are used to extract object frames from one image.
The object frames are then used to guide the prediction of the next frame.
'''

def extract_object_specific_frames_from_masks(images, masks, num_objects):
    """
    images: Tensor of shape [B, T, C, H, W]
    masks: Tensor of shape [B, T, H, W] with int values from 0 to num_objects-1
    num_objects: int, number of unique objects (including background if needed)

    Returns:
        object_frames: Tensor of shape [B, T, num_objects, C, H, W]
    """
    B, T, C, H, W = images.shape
    device = images.device

    # Expand images for each object
    object_frames = torch.zeros(B, T, num_objects, C, H, W, device=device, dtype=images.dtype)

    for obj_id in range(num_objects):
        # Create mask for this object: shape [B, T, 1, H, W]
        obj_mask = (masks == obj_id).unsqueeze(2)  # [B, T, 1, H, W]
        # Broadcast mask to all channels
        obj_mask = obj_mask.expand(-1, -1, C, -1, -1)  # [B, T, C, H, W]
        # Apply mask
        object_frames[:, :, obj_id] = images * obj_mask

    return object_frames

num_objects = 11
object_frames = extract_object_specific_frames_from_masks(rgbs, masks['masks'], num_objects)
object_frames.shape
Out[35]:
torch.Size([32, 24, 11, 3, 128, 128])
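For reference, the same extraction can be done without the Python loop over objects by one-hot encoding the masks and broadcasting. This vectorized variant is a sketch and not part of the project code:

```python
import torch

def extract_object_frames_vectorized(images, masks, num_objects):
    # One-hot the masks into [B, T, N, H, W], then broadcast against
    # images [B, T, 1, C, H, W] to produce [B, T, N, C, H, W] in one shot.
    obj_ids = torch.arange(num_objects, device=images.device).view(1, 1, -1, 1, 1)
    one_hot = (masks.unsqueeze(2) == obj_ids)          # [B, T, N, H, W]
    return images.unsqueeze(2) * one_hot.unsqueeze(3)  # [B, T, N, C, H, W]

# Tiny smoke test on random data
imgs = torch.rand(2, 4, 3, 8, 8)
msks = torch.randint(0, 11, (2, 4, 8, 8))
out = extract_object_frames_vectorized(imgs, msks, num_objects=11)
print(out.shape)  # torch.Size([2, 4, 11, 3, 8, 8])
```

Because every pixel belongs to exactly one object id, summing the result over the object dimension recovers the original images.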
  • As we can see, we extracted 11 different "object_frames" (one background + 10 objects in the MOVi-C dataset) from one image. Now we visualize them:
In [36]:
from utils.visualization import plot_object_frames

# Choose the first batch and time step
batch_idx = 3
seq_idx = 11

# object_frames shape: [B, T, num_objects, C, H, W]
# We'll print all object frames for this batch and time step
num_objects = len(object_frames[batch_idx][seq_idx])

plot_object_frames(rgbs, object_frames, batch_idx, seq_idx)
[Output images]
  2. Now we investigate extracting object frames from bounding boxes:
In [33]:
def extract_object_specific_frames_from_bboxes(images, bboxes):
    """
    images: Tensor of shape [B, T, C, H, W]
    bboxes: Tensor of shape [B, T, num_objects, 4] (x1, y1, x2, y2) in pixel coordinates

    Returns:
        object_frames: Tensor of shape [B, T, num_objects, C, H, W]
    """
    B, T, C, H, W = images.shape
    device = images.device
    num_objects = bboxes.shape[2]

    # Prepare output tensor
    object_frames = torch.zeros(B, T, num_objects, C, H, W, device=device, dtype=images.dtype)

    for obj_id in range(num_objects):
        for b in range(B):
            for t in range(T):
                x1, y1, x2, y2 = bboxes[b, t, obj_id]
                # Clamp coordinates to image bounds and convert to int
                x1 = int(torch.clamp(x1, 0, W-1).item())
                y1 = int(torch.clamp(y1, 0, H-1).item())
                x2 = int(torch.clamp(x2, 0, W-1).item())
                y2 = int(torch.clamp(y2, 0, H-1).item())
                # Ensure valid bbox
                if x2 > x1 and y2 > y1:
                    # Copy the region from the image to the corresponding location in object_frames
                    object_frames[b, t, obj_id, :, y1:y2, x1:x2] = images[b, t, :, y1:y2, x1:x2]
                # else: leave as zeros (background)
    return object_frames

object_frames = extract_object_specific_frames_from_bboxes(rgbs, coords['bbox'])
object_frames.shape
Out[33]:
torch.Size([32, 24, 11, 3, 128, 128])
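The nested loops above can likewise be replaced by broadcasting over coordinate grids. The sketch below is illustrative and may differ from the loop version in edge handling (e.g., it skips the explicit clamping of out-of-bounds boxes):

```python
import torch

def extract_object_frames_from_bboxes_vectorized(images, bboxes):
    """Loop-free variant: build a boolean inside-bbox mask per object
    from coordinate grids and broadcast it over the images."""
    B, T, C, H, W = images.shape
    xs = torch.arange(W, device=images.device).view(1, 1, 1, 1, W)
    ys = torch.arange(H, device=images.device).view(1, 1, 1, H, 1)
    x1, y1, x2, y2 = bboxes.unbind(-1)  # each [B, T, N]
    # Expand bbox edges to [B, T, N, 1, 1] so they broadcast against the grids
    inside = ((xs >= x1[..., None, None]) & (xs < x2[..., None, None]) &
              (ys >= y1[..., None, None]) & (ys < y2[..., None, None]))  # [B,T,N,H,W]
    return images.unsqueeze(2) * inside.unsqueeze(3)  # [B, T, N, C, H, W]

imgs = torch.rand(1, 2, 3, 16, 16)
boxes = torch.tensor([4.0, 4.0, 12.0, 12.0]).view(1, 1, 1, 4).expand(1, 2, 2, 4)
print(extract_object_frames_from_bboxes_vectorized(imgs, boxes).shape)
```

Pixels inside a box keep their RGB values; everything outside stays zero, matching the loop version's output for in-bounds boxes.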
In [34]:
plot_object_frames(rgbs, object_frames, batch_idx, seq_idx)
[Output images]

Patchifier

Another utility function patchifies input images. It is only used in the Holistic scene representation training, where each patch is considered a token.
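A minimal sketch of what such a patchifier typically does, using tensor.unfold (the project's Patchifier may order or flatten patches differently):

```python
import torch

def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split [B, T, C, H, W] frames into non-overlapping patch tokens of
    shape [B, T, (H//p)*(W//p), C*p*p]."""
    B, T, C, H, W = x.shape
    p = patch_size
    x = x.unfold(3, p, p).unfold(4, p, p)  # [B, T, C, H//p, W//p, p, p]
    x = x.permute(0, 1, 3, 4, 2, 5, 6)     # [B, T, H//p, W//p, C, p, p]
    return x.reshape(B, T, (H // p) * (W // p), C * p * p)

frames = torch.rand(2, 4, 3, 128, 128)
print(patchify(frames, 16).shape)  # torch.Size([2, 4, 64, 768])
```

With 128×128 frames and patch_size 16 this yields 8 × 8 = 64 patches of 3 × 16 × 16 = 768 values each, matching the shapes printed below.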

In [10]:
from model.model_utils import Patchifier

BATCH_IDX = 4
seq_len = 5
img = rgbs[BATCH_IDX, seq_len]

plt.figure(figsize=(3, 3))
plt.imshow(img.permute(1, 2, 0).cpu().numpy())
plt.axis("off")

patch_size = config['data']['patch_size'] # num of H and W pixels of each patch
patchifier = Patchifier(patch_size)
patch_data = patchifier(rgbs)
print(f"Patchified Shape: {patch_data.shape}") # (B, seq_len, num_patch_H * num_patch_W, 3 * patch_size * patch_size)

num_patches = patch_data.shape[2] # num_patches = num_patch_H * num_patch_W
print(f"Number of patches: {num_patches}")
print(f"Patch size: {patch_size}")
fig, ax = plt.subplots(1, num_patches)
fig.set_size_inches(3 * num_patches, 3)
for i in range(num_patches):
    cur_patch = patch_data[BATCH_IDX, seq_len, i].reshape(3, patch_size, patch_size)
    ax[i].imshow(cur_patch.permute(1, 2, 0).cpu().numpy())
    ax[i].set_title(f"Patch {i+1}")
    ax[i].axis("off")

plt.show()
Patchified Shape: torch.Size([32, 24, 64, 768])
Number of patches: 64
Patch size: 16
[Output images]

Experiments

All the experiments were run on different servers and machines at the University of Bonn. In particular, we used the cuda1, cuda2, cuda3, cuda4, and cuda6 machines, each with one NVIDIA GeForce RTX 4090/3090 GPU with 24 gigabytes of memory.

1. Holistic scene representation

- Holistic Transformer-AutoEncoder Module

In [4]:
from model.holistic_encoder import HolisticEncoder
from model.holistic_decoder import HolisticDecoder
from model.ocvp import TransformerAutoEncoder
  • We used a full-Transformer-based autoencoder for the Holistic scene scenario. Each image was patchified into tokens, positional embeddings were added, the tokens passed through transformer blocks, and the output was normalized to obtain the desired embeddings. During training, we experimented with different transformer architectures and input image sizes: (64×64) with moderate parameters, called Base (right table), and (128×128) with larger transformer parameters, called XL (left table). The configs were as follows:

XL model configs and learning process

Parameter                Value
model_name               02_Holistic_AE_XL
batch_size               32
patch_size               16
num_workers              8
num_epochs               100
warmup_epochs            5
early_stopping_patience  10
lr                       0.001
encoder_embed_dim        512
decoder_embed_dim        384
max_len                  64
in_out_channels          3
attn_dim                 128
num_heads                8
mlp_size                 1024
encoder_depth            12
decoder_depth            8
predictor_depth          8
predictor_embed_dim      256
residual                 true

Base model configs and learning process

Parameter                Value
model_name               01_Holistic_AE_Base
batch_size               32
patch_size               16
num_workers              8
num_epochs               100
warmup_epochs            5
early_stopping_patience  10
lr                       0.0002
encoder_embed_dim        128
decoder_embed_dim        64
max_len                  64
in_out_channels          3
attn_dim                 64
num_heads                8
mlp_size                 512
encoder_depth            6
decoder_depth            3
predictor_depth          4
predictor_embed_dim      128
residual                 true

Image 1 Image 2

Loss curves for different experiments:

Because the experiments were run on different machines, the TensorBoard logs from the machine associated with each experiment are provided here.

Image 1 Image 2

Note:

Our Base model could also learn properly with far fewer parameters, but in the end the reconstructions from the XL model were better and sharper, so we proceeded with the XL model for predictor training.

In [5]:
holistic_encoder = HolisticEncoder()

holistic_decoder = HolisticDecoder()
model = TransformerAutoEncoder(holistic_encoder, holistic_decoder).to(device)

summary(model, input_size= rgbs.shape)
Out[5]:
====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
TransformerAutoEncoder                             [8, 24, 3, 128, 128]      --
β”œβ”€HolisticEncoder: 1-1                             [8, 24, 64, 512]          --
β”‚    └─Sequential: 2-1                             [8, 24, 64, 512]          --
β”‚    β”‚    └─LayerNorm: 3-1                         [8, 24, 64, 768]          1,536
β”‚    β”‚    └─Linear: 3-2                            [8, 24, 64, 512]          393,728
β”‚    └─PositionalEncoding: 2-2                     [8, 24, 64, 512]          --
β”‚    └─Sequential: 2-3                             [8, 24, 64, 512]          --
β”‚    β”‚    └─TransformerBlock: 3-3                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-4                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-5                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-6                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-7                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-8                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-9                  [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-10                 [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-11                 [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-12                 [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-13                 [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-14                 [8, 24, 64, 512]          1,314,304
β”‚    └─LayerNorm: 2-4                              [8, 24, 64, 512]          1,024
β”œβ”€HolisticDecoder: 1-2                             [8, 24, 3, 128, 128]      --
β”‚    └─Linear: 2-5                                 [8, 24, 64, 384]          196,992
β”‚    └─PositionalEncoding: 2-6                     [8, 24, 64, 384]          --
β”‚    └─Sequential: 2-7                             --                        --
β”‚    β”‚    └─TransformerBlock: 3-15                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-16                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-17                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-18                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-19                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-20                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-21                 [8, 24, 64, 384]          985,984
β”‚    β”‚    └─TransformerBlock: 3-22                 [8, 24, 64, 384]          985,984
β”‚    └─LayerNorm: 2-8                              [8, 24, 64, 384]          768
β”‚    └─Linear: 2-9                                 [8, 24, 64, 768]          295,680
====================================================================================================
Total params: 24,549,248
Trainable params: 24,549,248
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 413.45
====================================================================================================
Input size (MB): 37.75
Forward/backward pass size (MB): 6719.28
Params size (MB): 98.20
Estimated Total Size (MB): 6855.22
====================================================================================================
In [8]:
# Full forward pass through the model

with torch.no_grad():
    encoded_features = holistic_encoder(rgbs)
    print("Encoded Features shape:", encoded_features.shape)
    recons, loss = holistic_decoder(encoded_features)
    print("Reconstructed image shape:", recons.shape)
print(f"Reconstructed images match the original images shape: {recons.shape == rgbs.shape}")
Encoded Features shape: torch.Size([8, 24, 16, 512])
Reconstructed image shape: torch.Size([8, 24, 3, 64, 64])
Reconstructed images match the original images shape: True

- Loading Holistic-AE pre-trained checkpoints

1. The XL model

Now we evaluate our pre-trained autoencoders by running a full forward pass through the network using eval images and visualizing the results for some sequences.

In [12]:
path_AE = 'experiments/02_Holistic_AE_XL/checkpoints/best_02_Holistic_AE_XL.pth'
model = load_model(model, mode='AE_inference', path_AE = path_AE)
In [ ]:
from utils.visualization import plot_images_vs_recons

# Generate reconstructions
with torch.no_grad():
    recons, _ = model(rgbs)

# Plot 5 random sequences, each showing 8 original vs 8 reconstructed frames
plot_images_vs_recons(rgbs, recons)
[Output image]
  • NOTE:

As we can see, our autoencoder has efficiently learned the latent space of the input images and can reconstruct/map any embedding in the given latent space to its corresponding input image.

- PSNR, SSIM, and LPIPS metrics analysis:
In [ ]:
from utils.metrics import evaluate_metrics

evaluate_metrics(recons, rgbs)
PSNR Mean: 26.412456512451172
PSNR Framewise: tensor([26.7287, 26.6270, 27.3801, 26.3104, 26.7908, 27.1606, 26.2990, 26.9788,
        26.5026, 26.6112, 26.3883, 26.8132, 26.4445, 25.7483, 26.5985, 25.8677,
        25.6586, 25.8612, 26.2204, 26.4984, 25.9288, 26.0611, 26.1334, 26.2871],
       device='cuda:0')
SSIM Mean: 0.9157991409301758
SSIM Framewise: tensor([0.9209, 0.9205, 0.9219, 0.9152, 0.9161, 0.9146, 0.9131, 0.9188, 0.9141,
        0.9140, 0.9158, 0.9166, 0.9153, 0.9162, 0.9202, 0.9115, 0.9147, 0.9133,
        0.9133, 0.9164, 0.9152, 0.9152, 0.9133, 0.9129], device='cuda:0')
LPIPS Mean: 0.03445807844400406
LPIPS Framewise: tensor([0.0360, 0.0340, 0.0341, 0.0379, 0.0334, 0.0351, 0.0360, 0.0341, 0.0329,
        0.0335, 0.0350, 0.0349, 0.0343, 0.0336, 0.0350, 0.0368, 0.0354, 0.0363,
        0.0318, 0.0322, 0.0329, 0.0339, 0.0340, 0.0338], device='cuda:0')
2. The Base model
In [13]:
path_AE = 'experiments/01_Holistic_AE_Base/checkpoints/best_01_Holistic_AE_Base.pth'
model = load_model(model, mode='AE_inference', path_AE = path_AE)
summary(model, input_size= rgbs.shape)
Out[13]:
====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
TransformerAutoEncoder                             [8, 24, 3, 128, 128]      --
β”œβ”€HolisticEncoder: 1-1                             [8, 24, 64, 128]          --
β”‚    └─Sequential: 2-1                             [8, 24, 64, 128]          --
β”‚    β”‚    └─LayerNorm: 3-1                         [8, 24, 64, 768]          1,536
β”‚    β”‚    └─Linear: 3-2                            [8, 24, 64, 128]          98,432
β”‚    └─PositionalEncoding: 2-2                     [8, 24, 64, 128]          --
β”‚    └─Sequential: 2-3                             [8, 24, 64, 128]          --
β”‚    β”‚    └─TransformerBlock: 3-3                  [8, 24, 64, 128]          164,992
β”‚    β”‚    └─TransformerBlock: 3-4                  [8, 24, 64, 128]          164,992
β”‚    β”‚    └─TransformerBlock: 3-5                  [8, 24, 64, 128]          164,992
β”‚    β”‚    └─TransformerBlock: 3-6                  [8, 24, 64, 128]          164,992
β”‚    β”‚    └─TransformerBlock: 3-7                  [8, 24, 64, 128]          164,992
β”‚    β”‚    └─TransformerBlock: 3-8                  [8, 24, 64, 128]          164,992
β”‚    └─LayerNorm: 2-4                              [8, 24, 64, 128]          256
β”œβ”€HolisticDecoder: 1-2                             [8, 24, 3, 128, 128]      --
β”‚    └─Linear: 2-5                                 [8, 24, 64, 64]           8,256
β”‚    └─PositionalEncoding: 2-6                     [8, 24, 64, 64]           --
β”‚    └─Sequential: 2-7                             --                        --
β”‚    β”‚    └─TransformerBlock: 3-9                  [8, 24, 64, 64]           82,752
β”‚    β”‚    └─TransformerBlock: 3-10                 [8, 24, 64, 64]           82,752
β”‚    β”‚    └─TransformerBlock: 3-11                 [8, 24, 64, 64]           82,752
β”‚    └─LayerNorm: 2-8                              [8, 24, 64, 64]           128
β”‚    └─Linear: 2-9                                 [8, 24, 64, 768]          49,920
====================================================================================================
Total params: 1,396,736
Trainable params: 1,396,736
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 22.48
====================================================================================================
Input size (MB): 37.75
Forward/backward pass size (MB): 1189.09
Params size (MB): 5.59
Estimated Total Size (MB): 1232.42
====================================================================================================
In [9]:
from utils.visualization import plot_images_vs_recons

# Generate reconstructions
with torch.no_grad():
    recons, _ = model(rgbs)

# Plot 5 random sequences, each showing 8 original vs 8 reconstructed frames
plot_images_vs_recons(rgbs, recons)
[Output image]
- PSNR, SSIM, and LPIPS metrics analysis:
In [12]:
from utils.metrics import evaluate_metrics

evaluate_metrics(recons, rgbs)
PSNR Mean: 21.13724136352539
PSNR Framewise: tensor([22.7105, 23.3850, 22.2196, 21.7082, 20.2734, 20.9578, 20.5842, 20.9267,
        20.5974, 20.4753, 21.1178, 21.7913, 21.4562, 20.7999, 20.8086, 21.4605,
        21.2202, 20.4990, 20.5013, 20.6580, 20.8458, 20.3676, 20.8827, 21.0468],
       device='cuda:0')
SSIM Mean: 0.7252386808395386
SSIM Framewise: tensor([0.7582, 0.7554, 0.7433, 0.7292, 0.7126, 0.7124, 0.7161, 0.7119, 0.7105,
        0.7053, 0.7140, 0.7222, 0.7255, 0.7237, 0.7231, 0.7261, 0.7281, 0.7220,
        0.7208, 0.7227, 0.7281, 0.7289, 0.7306, 0.7350], device='cuda:0')
LPIPS Mean: 0.3330411911010742
LPIPS Framewise: tensor([0.3265, 0.3193, 0.3220, 0.3287, 0.3255, 0.3414, 0.3321, 0.3354, 0.3517,
        0.3403, 0.3370, 0.3318, 0.3359, 0.3366, 0.3474, 0.3420, 0.3471, 0.3330,
        0.3285, 0.3270, 0.3214, 0.3293, 0.3256, 0.3276], device='cuda:0')

- Performance analysis and comparison


  • PSNR (Peak Signal-to-Noise Ratio):
    Measures how close your prediction is to the ground truth at the pixel level.
    Higher is better (means less error/noise).

  • SSIM (Structural Similarity Index):
    Measures how well your prediction preserves the structure and details of the ground truth image.
    Higher is better (closer to 1 means almost identical structure).

  • LPIPS (Learned Perceptual Image Patch Similarity):
    Measures perceptual similarity using a neural network, closer to how humans judge images.
    Lower is better (0 means visually identical to a human).
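For reference, PSNR follows directly from the mean squared error as 10 · log10(MAX² / MSE). The helper below is a sketch, not the project's utils.metrics implementation:

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """PSNR = 10 * log10(MAX^2 / MSE); higher means less pixel-wise error."""
    mse = torch.mean((pred - target) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)

# Constant images differing by 0.1 everywhere: MSE = 0.01
pred = torch.full((3, 8, 8), 0.5)
target = torch.full((3, 8, 8), 0.6)
print(psnr(pred, target))  # MSE = 0.01 -> PSNR ≈ 20 dB
```

The same MSE-based logic underlies the framewise PSNR tensors reported above; SSIM and LPIPS require structural and learned-feature comparisons, respectively, and are typically taken from dedicated libraries.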


Metric  Description            Better  XL Model  Base Model  Which is better
PSNR    Pixel-wise fidelity    Higher  26.41     21.14       XL
SSIM    Structural similarity  Higher  0.916     0.725       XL
LPIPS   Perceptual similarity  Lower   0.0345    0.3330      XL

As we can see, the XL model is significantly superior to the Base model in pixel fidelity, structural integrity, and perceptual similarity, producing much more realistic and accurate predictions.

- Holistic Transformer-Predictor Module

We chose the XL model due to its better performance and trained the predictor in the second phase of training.

In [9]:
from model.holistic_predictor import HolisticTransformerPredictor
from model.predictor_wrapper import PredictorWrapper
from model.ocvp import TransformerPredictor
In [10]:
holistic_predictor = HolisticTransformerPredictor()
holistic_predictor= PredictorWrapper(holistic_predictor)
model = TransformerPredictor(holistic_encoder, holistic_decoder, holistic_predictor, mode='inference').to(device)

summary(model, input_size= rgbs.shape)
Out[10]:
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
TransformerPredictor                                    [8, 5, 64, 512]           8,381,312
β”œβ”€HolisticEncoder: 1-1                                  [8, 24, 64, 512]          --
β”‚    └─Sequential: 2-1                                  [8, 24, 64, 512]          --
β”‚    β”‚    └─LayerNorm: 3-1                              [8, 24, 64, 768]          1,536
β”‚    β”‚    └─Linear: 3-2                                 [8, 24, 64, 512]          393,728
β”‚    └─PositionalEncoding: 2-2                          [8, 24, 64, 512]          --
β”‚    └─Sequential: 2-3                                  [8, 24, 64, 512]          --
β”‚    β”‚    └─TransformerBlock: 3-3                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-4                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-5                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-6                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-7                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-8                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-9                       [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-10                      [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-11                      [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-12                      [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-13                      [8, 24, 64, 512]          1,314,304
β”‚    β”‚    └─TransformerBlock: 3-14                      [8, 24, 64, 512]          1,314,304
β”‚    └─LayerNorm: 2-4                                   [8, 24, 64, 512]          1,024
β”œβ”€PredictorWrapper: 1-2                                 [8, 5, 64, 512]           --
β”‚    └─HolisticTransformerPredictor: 2-5                [8, 5, 64, 512]           --
β”‚    β”‚    └─Linear: 3-15                                [8, 5, 64, 256]           131,328
β”‚    β”‚    └─PositionalEncoding: 3-16                    [8, 5, 64, 256]           --
β”‚    β”‚    └─Sequential: 3-17                            [8, 5, 64, 256]           5,261,312
β”‚    β”‚    └─LayerNorm: 3-18                             [8, 5, 64, 256]           512
β”‚    β”‚    └─Linear: 3-19                                [8, 5, 64, 512]           131,584
β”‚    └─HolisticTransformerPredictor: 2-6                [8, 5, 64, 512]           (recursive)
β”‚    β”‚    └─Linear: 3-20                                [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─PositionalEncoding: 3-21                    [8, 5, 64, 256]           --
β”‚    β”‚    └─Sequential: 3-22                            [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─LayerNorm: 3-23                             [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─Linear: 3-24                                [8, 5, 64, 512]           (recursive)
β”‚    └─HolisticTransformerPredictor: 2-7                [8, 5, 64, 512]           (recursive)
β”‚    β”‚    └─Linear: 3-25                                [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─PositionalEncoding: 3-26                    [8, 5, 64, 256]           --
β”‚    β”‚    └─Sequential: 3-27                            [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─LayerNorm: 3-28                             [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─Linear: 3-29                                [8, 5, 64, 512]           (recursive)
β”‚    └─HolisticTransformerPredictor: 2-8                [8, 5, 64, 512]           (recursive)
β”‚    β”‚    └─Linear: 3-30                                [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─PositionalEncoding: 3-31                    [8, 5, 64, 256]           --
β”‚    β”‚    └─Sequential: 3-32                            [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─LayerNorm: 3-33                             [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─Linear: 3-34                                [8, 5, 64, 512]           (recursive)
β”‚    └─HolisticTransformerPredictor: 2-9                [8, 5, 64, 512]           (recursive)
β”‚    β”‚    └─Linear: 3-35                                [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─PositionalEncoding: 3-36                    [8, 5, 64, 256]           --
β”‚    β”‚    └─Sequential: 3-37                            [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─LayerNorm: 3-38                             [8, 5, 64, 256]           (recursive)
β”‚    β”‚    └─Linear: 3-39                                [8, 5, 64, 512]           (recursive)
=========================================================================================================
Total params: 30,073,984
Trainable params: 30,073,984
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 536.98
=========================================================================================================
Input size (MB): 37.75
Forward/backward pass size (MB): 6350.18
Params size (MB): 86.77
Estimated Total Size (MB): 6474.70
=========================================================================================================

Loss curves for different experiments:

Image 1 Image 2

  • We tried many different combinations of experiments for the predictor. The model architecture and embedding sizes of the whole network had to match those of the autoencoder for predictor training, so we were effectively forced to stick to the initial AE model configs during predictor training. For the visualizations here, we chose the "best_03_Holistic_Predictor_XL" model, which reached the best results.
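The two-phase setup described above can be sketched with toy modules (the classes below are hypothetical stand-ins, not the project's actual models): the pre-trained autoencoder is frozen and only the predictor's parameters go to the optimizer.

```python
import torch

class TinyAE(torch.nn.Module):
    """Toy stand-in for the pre-trained encoder/decoder pair."""
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 4)
        self.decoder = torch.nn.Linear(4, 8)

class TinyPredictor(torch.nn.Module):
    """Toy stand-in for the transformer predictor."""
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 4)

ae, predictor = TinyAE(), TinyPredictor()
for p in ae.parameters():  # freeze the pre-trained autoencoder
    p.requires_grad = False
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-3)

trainable = sum(p.numel() for p in predictor.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in ae.parameters() if not p.requires_grad)
print(trainable, frozen)
```

Freezing the AE is what forces the predictor's input/output embedding sizes to match the AE's latent dimension, which is why the AE configs could not change between phases.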

Loading pre-trained Holistic-Predictor checkpoints

In [22]:
path_AE = 'experiments/02_Holistic_AE_XL/checkpoints/best_02_Holistic_AE_XL.pth'
path_predictor = 'experiments/03_Holistic_Predictor_XL/checkpoints/best_03_Holistic_Predictor_XL.pth'

model = load_model(model, mode='inference', path_AE = path_AE, path_predictor=path_predictor)
In [ ]:
from utils.visualization import plot_predictor_images

# Visualize 3 sequences
idx = [0, 3, 7]

with torch.no_grad():
    for i in idx:
        preds, loss, input_range, target_range = model(rgbs[i].unsqueeze(0))

        recons, _ = model.decoder(preds)
        input_images = rgbs[i,input_range[0]:input_range[1]].unsqueeze(0)
        target_images = rgbs[i,target_range[0]:target_range[1]].unsqueeze(0)
        plot_predictor_images(input_images, target_images, recons)
        
(Visualizations of input, target, and predicted frames for the three sequences.)

NoteΒΆ

As we can see, our predictor model could do better at predicting the future frames and providing more meaningful embeddings for the decoder module. Since our pre-trained AE module was robust, we believe this is fully attributable to the predictor module. The results are the best we could get with the current network architecture. Perhaps with deeper transformer blocks in the predictor and a smaller embedding size, our predictor could produce more meaningful predictions. However, given that the Holistic scenario carries no information about individual objects, we could not expect the predictor to make perfect predictions (we can expect that from our object-centric predictor), but we could at least expect better reconstructions and more meaningful embeddings. That, however, would require end-to-end training with new configs.

2. Object-Centric Scene RepresentationΒΆ

  • Object-Centric Transformer-AutoEncoder Module

In this scenario, instead of image patches, each extracted object_frame (as explained above) would be a token.
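As a rough sketch of that tokenization (illustrative shapes and names only, not the project's actual modules), each 3×64×64 object frame can be flattened and linearly projected into one token embedding per object per frame:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: turning per-object crops into transformer tokens.
# Shapes mirror the notebook outputs: batch B, frames T, N objects, 3x64x64 crops.
B, T, N, C, H, W = 2, 4, 11, 3, 64, 64
embed_dim = 256

object_frames = torch.randn(B, T, N, C, H, W)  # one 64x64 crop per object per frame

# A linear projection maps each flattened crop to a single token embedding.
to_token = nn.Linear(C * H * W, embed_dim)
tokens = to_token(object_frames.flatten(start_dim=3))  # (B, T, N, embed_dim)
print(tokens.shape)  # torch.Size([2, 4, 11, 256])
```

The transformer then treats the `N` object embeddings per frame as its token sequence, instead of patch embeddings.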

We ran many experiments for this case! The first set of experiments used linear layers at the encoder input (to produce the input embeddings for the transformer module) and at the decoder output (to reconstruct the input image from the decoder embedding). The linear layers impose a large number of parameters on the network, but they are fast to train. Below are our results for the OC-AE using this setup.


OC-AE XL 64 Linear OC-AE XL 64 Advanced 2 Linear

OC-AE XL 64 Linear (repeat) OC-AE XL 64 Advanced 3 Linear

NOTEΒΆ

As we can see, the reconstructions were not the best and remained blurry. We believe this was due to the large compression in the encoder layer (3*64*64 ---> encoder_embed_dim) and the corresponding decompression in the decoder (decoder_embed_dim ---> 3*64*64), which introduced a huge number of parameters (nearly 800M parameters needed to do the job!) and made convergence hard.
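A back-of-the-envelope check of the projection cost (dims as used in this notebook: 3×64×64 object frames, 256/192 embeddings) shows that even a single linear projection costs millions of parameters; the nearly 800M total mentioned above covers the whole linear-layer setup, not just these two projections:

```python
# Parameter count of one linear encoder projection and one decoder projection,
# using the dims from this notebook's config.
in_features = 3 * 64 * 64                      # 12288 values per flattened object frame
enc_params = in_features * 256 + 256           # encoder projection: weight + bias
dec_params = 192 * in_features + in_features   # decoder projection back to pixels
print(enc_params, dec_params)                  # 3145984 2371584
```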

For this reason, it made a lot of sense to replace those layers with a comparable CNN-based network. Therefore, we ran further experiments with a hybrid Transformer-CNN architecture. The results are as follows:

OC-AE XL 64 Hybrid 1 OC-AE XL 64 Hybrid 2

OC-AE XL 64 Full CNN OC-AE XL 64 Mixed CNN Decoder Linear Encoder

OC-AE XL 64 Full CNN 64 OC-AE XL 64 Full CNN New Module

NoteΒΆ

Although we tried many different CNN architectures and experiments (from various simple to complex CNN encoder-decoders, to a mixed combination of a linear encoder and a CNN decoder to get better reconstructions), we were unsuccessful in getting the best reconstructions using CNN networks! We mostly used the mask labels to extract the object frames. We also tried more expressive loss functions (a mixture of MSE and L1 loss). We found it hard to reach a well-working combination of CNN and transformer networks for the OC-AE task. For this reason, we proceeded with our best OC-AE model so far. This model had the following configs:

Parameter                 Value
model_name                01_OC_AE_XL_64_Full_CNN
batch_size                32
patch_size                16
max_objects               11
image_height              64
image_width               64
num_epochs                100
warmup_epochs             15
early_stopping_patience   15
lr                        0.001
encoder_embed_dim         256
decoder_embed_dim         192
max_len                   64
in_out_channels           3
attn_dim                  128
num_heads                 8
mlp_size                  1024
encoder_depth             12
decoder_depth             8
predictor_depth           8
predictor_embed_dim       192
num_preds                 5
predictor_window_size     5
residual                  true
use_masks                 true
use_bboxes                false
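For reference, the same settings written out as a plain Python dict (a hypothetical flat layout; the project actually reads its settings from CONFIG.py):

```python
# Illustrative flat dict of the OC-AE config listed above.
oc_ae_config = {
    "model_name": "01_OC_AE_XL_64_Full_CNN",
    "batch_size": 32,
    "patch_size": 16,
    "max_objects": 11,
    "image_height": 64,
    "image_width": 64,
    "num_epochs": 100,
    "warmup_epochs": 15,
    "early_stopping_patience": 15,
    "lr": 0.001,
    "encoder_embed_dim": 256,
    "decoder_embed_dim": 192,
    "max_len": 64,
    "in_out_channels": 3,
    "attn_dim": 128,
    "num_heads": 8,
    "mlp_size": 1024,
    "encoder_depth": 12,
    "decoder_depth": 8,
    "predictor_depth": 8,
    "predictor_embed_dim": 192,
    "num_preds": 5,
    "predictor_window_size": 5,
    "residual": True,
    "use_masks": True,
    "use_bboxes": False,
}
```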
InΒ [4]:
from model.oc_encoder import ObjectCentricEncoder
from model.oc_decoder import ObjectCentricDecoder
from model.ocvp import TransformerAutoEncoder
InΒ [8]:
oc_encoder = ObjectCentricEncoder()
oc_decoder = ObjectCentricDecoder()

model = TransformerAutoEncoder(oc_encoder, oc_decoder).to(device)

# summary(model, input_size= rgbs.shape)
InΒ [6]:
with torch.no_grad():
    encoded_features = oc_encoder(rgbs, masks, coords)
    print("Encoded features shape:", encoded_features.shape)
    recons, loss = oc_decoder(encoded_features, rgbs)
    print("Reconstructed output shape:", recons.shape)
Encoded features shape: torch.Size([32, 24, 11, 256])
Reconstructed output shape: torch.Size([32, 24, 3, 64, 64])

We can see that our CNN encoder properly produces 11 embeddings, one per object in the scene, and that the decoder reconstructs the input back to its original size.

- Loading Object-Centric-AE model pre-trained checkpoints¶

InΒ [10]:
from utils.visualization import plot_images_vs_recons

path_AE = 'experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth'
model = load_model(model, mode='AE_inference', path_AE = path_AE)

# Generate reconstructions
with torch.no_grad():
    recons, _ = model(rgbs,masks)

# Plot 5 random sequences, each showing 8 original vs 8 reconstructed frames
plot_images_vs_recons(rgbs, recons)
(Original vs. reconstructed frames for the plotted sequences.)
- PSNR, SSIM, and LPIPS metrics analysis:ΒΆ
InΒ [Β ]:
from utils.metrics import evaluate_metrics

evaluate_metrics(recons, rgbs)
PSNR Mean: 24.983428955078125
PSNR Framewise: tensor([25.6122, 25.3229, 24.9899, 24.8046, 24.5098, 24.4270, 24.4462, 24.5905,
        24.6441, 24.6562, 24.6746, 24.7518, 24.8696, 24.9073, 24.9973, 25.0703,
        25.1042, 25.1092, 25.1894, 25.2762, 25.3370, 25.4012, 25.4492, 25.4616],
       device='cuda:0')
SSIM Mean: 0.7169415354728699
SSIM Framewise: tensor([0.7294, 0.7254, 0.7211, 0.7150, 0.7064, 0.7033, 0.7019, 0.7087, 0.7133,
        0.7163, 0.7199, 0.7191, 0.7202, 0.7197, 0.7205, 0.7200, 0.7193, 0.7177,
        0.7189, 0.7175, 0.7160, 0.7175, 0.7201, 0.7194], device='cuda:0')
LPIPS Mean: 0.24213504791259766
LPIPS Framewise: tensor([0.2474, 0.2425, 0.2369, 0.2456, 0.2531, 0.2534, 0.2601, 0.2444, 0.2392,
        0.2403, 0.2391, 0.2362, 0.2336, 0.2404, 0.2380, 0.2475, 0.2408, 0.2410,
        0.2409, 0.2388, 0.2419, 0.2406, 0.2374, 0.2321], device='cuda:0')

NOTEΒΆ

Not too bad!! Not the best, and it could be better, but we proceeded to predictor training with this model.
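For reference, a minimal frame-wise PSNR along the lines of what evaluate_metrics reports above (an illustrative re-implementation, not the project's actual code):

```python
import torch

def psnr(pred: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    """Frame-wise PSNR for inputs shaped (B, T, C, H, W): one value per frame."""
    mse = ((pred - target) ** 2).mean(dim=(0, 2, 3, 4))
    return 10 * torch.log10(max_val ** 2 / mse)

# Constant error of 0.1 everywhere -> MSE = 0.01 -> 20 dB for every frame.
target = torch.zeros(2, 4, 3, 8, 8)
pred = torch.full_like(target, 0.1)
print(psnr(pred, target))  # tensor([20., 20., 20., 20.]) (up to float rounding)
```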

  • Object-Centric Transformer-Predictor Module
InΒ [17]:
from model.oc_predictor import ObjectCentricTransformerPredictor
from model.predictor_wrapper import PredictorWrapper
from model.ocvp import TransformerPredictor

oc_predictor = ObjectCentricTransformerPredictor()
predictor = PredictorWrapper(oc_predictor)
model = TransformerPredictor(oc_encoder, oc_decoder, predictor, mode='inference').to(device)
InΒ [25]:
from utils.visualization import plot_predictor_images

path_AE = 'experiments/01_OC_AE_XL_64_Full_CNN/checkpoints/best_01_OC_AE_XL_64_Full_CNN.pth'
path_predictor = 'experiments/01_OC_Predictor_XL/checkpoints/best_01_OC_Predictor_XL.pth'

model = load_model(model, mode='inference', path_AE = path_AE, path_predictor=path_predictor)

# We visualize 3 random sequences
idx = [0,3,7]

with torch.no_grad():
    for i in idx:
        # Get a single sequence
        sequence_rgbs = rgbs[i:i+1]  # This keeps the batch dimension
        sequence_masks = {k: v[i:i+1] for k, v in masks.items()}  # Handle masks dictionary
        
        # Get predictions
        preds, loss, input_range, target_range = model(sequence_rgbs, sequence_masks)
        
        # Get reconstructions
        recons, _ = model.decoder(preds)
        
        # Get input and target ranges
        start_input, end_input = input_range
        start_target, end_target = target_range
        
        # Extract relevant frames
        input_images = sequence_rgbs[:, start_input:end_input]
        target_images = sequence_rgbs[:, start_target:end_target]
        
        # Visualize
        plot_predictor_images(input_images, target_images, recons)
        
(Input, target, and predicted frames for the three sequences.)

NOTEΒΆ

We can see that our OC-Predictor produces much better future-scene predictions than the Holistic predictor module. This is because, in the object-centric scenario, the model benefits from the object frames extracted from the masks/bboxes, and through multi-head attention it can capture the temporal relations between the different objects in the scene.
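A minimal sketch of that idea (illustrative dims only; the actual predictor lives in model/oc_predictor.py): flatten the (frames, objects) axes into a single token sequence, so attention can relate every object to every other object at every timestep:

```python
import torch
import torch.nn as nn

B, T, N, D = 2, 5, 11, 256
tokens = torch.randn(B, T, N, D)  # per-object embeddings for each frame

# Every object at every timestep becomes one token in the attention sequence.
seq = tokens.reshape(B, T * N, D)
attn = nn.MultiheadAttention(embed_dim=D, num_heads=8, batch_first=True)
out, weights = attn(seq, seq, seq)  # each token attends over all objects/frames
print(out.shape)  # torch.Size([2, 55, 256])
```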